{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3.5.5 完整神经网络样例程序\n",
    "可以分为三个步骤：\n",
    "1. 定义神经网络参数，输入和输出节点；\n",
    "2. 定义前向传播、损失函数和反向传播优化算法；\n",
    "3. 生成会话，并在训练集上反复运行反向传播优化算法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from numpy.random import RandomState\n",
    "\n",
    "# 生成模拟数据集\n",
    "rdm = RandomState(1)\n",
    "X = rdm.rand(128, 2)\n",
    "Y = [[int(x1 + x2 < 1)] for (x1, x2) in X]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 1. 定义神经网络的参数，输入和输出节点。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 8  # 每一个batch中样本的数量\n",
    "STEPS = 5000    # NN运行迭代的轮数\n",
    "\n",
    "w1= tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))\n",
    "w2= tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))\n",
    "\n",
    "x = tf.placeholder(tf.float32, shape=(None, 2), name=\"x-input\")\n",
    "y_= tf.placeholder(tf.float32, shape=(None, 1), name='y-input')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 2. 定义前向传播过程，损失函数及反向传播算法。\n",
    "损失函数和反向传播第四章会有详述。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = tf.matmul(x, w1)\n",
    "y = tf.matmul(a, w2)\n",
    "y = tf.sigmoid(y)\n",
    "\n",
    "cross_entropy = -tf.reduce_mean(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))\n",
    "                                + (1 - y_) * tf.log(tf.clip_by_value(1 - y, 1e-10, 1.0)))\n",
    "train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 3. 创建一个会话来运行TensorFlow程序。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-0.8113182   1.4845988   0.06532937]\n",
      " [-2.4427042   0.0992484   0.5912243 ]]\n",
      "[[-0.8113182 ]\n",
      " [ 1.4845988 ]\n",
      " [ 0.06532937]]\n",
      "\n",
      "\n",
      "After 0 training step(s), cross entropy on all data is 1.89805\n",
      "After 1000 training step(s), cross entropy on all data is 0.655075\n",
      "After 2000 training step(s), cross entropy on all data is 0.626172\n",
      "After 3000 training step(s), cross entropy on all data is 0.615096\n",
      "After 4000 training step(s), cross entropy on all data is 0.610309\n",
      "\n",
      "\n",
      "[[ 0.02476974  0.56948686  1.6921943 ]\n",
      " [-2.1977353  -0.23668927  1.1143897 ]]\n",
      "[[-0.45544702]\n",
      " [ 0.49110925]\n",
      " [-0.98110336]]\n"
     ]
    }
   ],
   "source": [
    "with tf.Session() as sess:\n",
    "    init_op = tf.global_variables_initializer()\n",
    "    sess.run(init_op)\n",
    "    \n",
    "    # 输出目前（未经训练）的参数取值。\n",
    "    print(sess.run(w1))\n",
    "    print(sess.run(w2))\n",
    "    print(\"\\n\")\n",
    "    \n",
    "    # 训练模型。\n",
    "    for i in range(STEPS):\n",
    "        start = (i * BATCH_SIZE) % 128\n",
    "        end = (i * BATCH_SIZE) % 128 + BATCH_SIZE\n",
    "        sess.run([train_step, y, y_], feed_dict={x:X[start:end], y_:Y[start:end]})\n",
    "        if i % 1000 == 0:\n",
    "            total_cross_entropy = sess.run(cross_entropy, feed_dict={x:X, y_:Y})\n",
    "            print(\"After %d training step(s), cross entropy on all data is %g\" % (i, total_cross_entropy))\n",
    "    \n",
    "    # 输出训练后的参数取值。\n",
    "    print(\"\\n\")\n",
    "    print(sess.run(w1))\n",
    "    print(sess.run(w2))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
